TFRecords (Part 2): Reading and training models with TFRecords
In the first part, I showed how to convert a dataset of medical images and their target values into tfrecords. Here, I will show how to read tfrecords and how to train an ML model in TensorFlow using them. For this, I will be using the tfrecords I created in the first part.
Reading TFRecords
Tfrecords store data in a binary format for fast and easy access. To read an example from a tfrecord file, we first create a function to parse the example and then use a tf.data.TFRecordDataset object together with that parsing function to read the file. The process is straightforward.
In the first part, we put two features (image and target) from our dataset into the tfrecord files. First, we create a feature_description dictionary describing how to read each feature. The keys in the dictionary must match the keys used to store the features in the tfrecord files. tf.io.FixedLenFeature reads each feature and stores it in the given data type. Here, tf.string is used for the image feature, since the PNG-encoded image was stored as a byte string, and tf.int64 is used for the target value, since it's an integer.
import tensorflow as tf

def parse_tfrecord_fn(example):
    # Describe the features stored in each serialized example
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(example, feature_description)
    # Decode the PNG byte string back into an image tensor
    example["image"] = tf.io.decode_png(example["image"])
    return example
The parse_tfrecord_fn function reads each example and maps the features to the corresponding data types. The tf.io.decode_png function does the opposite of the tf.io.encode_png function introduced in Part 1: it converts the byte string back into an image tensor by decoding the PNG-compressed data.
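To see this round trip in isolation, here is a minimal sketch using a dummy grayscale tensor (not an image from the dataset):

# A dummy 224x224 grayscale image (all zeros), just for illustration
image = tf.zeros((224, 224, 1), dtype=tf.uint8)
png_bytes = tf.io.encode_png(image)    # image tensor -> PNG byte string
decoded = tf.io.decode_png(png_bytes)  # PNG byte string -> image tensor
print(decoded.shape)  # (224, 224, 1)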
After creating the parsing function, it's fairly easy to read the files. First, use tf.data.TFRecordDataset to read the tfrecord file from its path, then call the map method with parse_tfrecord_fn as the argument.
raw_dataset = tf.data.TFRecordDataset(".../tfrecords/tfrecord_0-1000.tfrec")
parsed_dataset = raw_dataset.map(parse_tfrecord_fn)
Now that the dataset is parsed, we read examples from it with the take method, passing the number of examples we want as the argument. This is similar to how we use df.head() in pandas.
import matplotlib.pyplot as plt

for example in parsed_dataset.take(5):
    for key in example.keys():
        print(f"{key}: {type(example[key])}")
    print(f"Image shape: {example['image'].shape}")
    plt.figure(figsize=(7, 7))
    # Squeeze the trailing channel dimension so matplotlib can plot the grayscale image
    plt.imshow(example["image"].numpy().squeeze(), cmap="gray")
    plt.show()
parsed_dataset.take(5) takes five examples from the dataset. You can access them with a for loop or by converting them to a list with list(parsed_dataset.take(5)). Either way, each example is stored as a dictionary whose features and values form key-value pairs.
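For instance, with the list approach:

examples = list(parsed_dataset.take(5))
print(examples[0].keys())             # dict_keys(['image', 'target'])
print(examples[0]["target"].numpy())  # the first example's target value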
Training a TensorFlow model using TFRecords
Training a TensorFlow/Keras deep learning model using tfrecords is very easy. First, define the model. Since the dataset consists of images and their target values, I have chosen the VGG-16 model from tf.keras.applications and I will be applying transfer learning: the model weights are set to imagenet and the trainability of the VGG-16 base is set to False. Here's how it's done using the Keras functional API.
num_classes = 1

# Model input; the name must match the feature key in the tfrecords
input_image = tf.keras.Input(shape=(224, 224, 3), name='image')

# Load the VGG16 base with ImageNet weights and freeze it
vgg16 = tf.keras.applications.VGG16(weights='imagenet',
                                    include_top=False,
                                    input_shape=(224, 224, 3))
vgg16.trainable = False

x = vgg16(input_image)
y = tf.keras.layers.Flatten()(x)
z = tf.keras.layers.Dense(128, activation='relu')(y)
output = tf.keras.layers.Dense(num_classes, activation='sigmoid')(z)

# Create the model
model = tf.keras.Model(inputs=[input_image], outputs=output)
model.summary()
The tf.keras.Input line in the code above defines the input of the model. Here, we use a tensor of shape (224, 224, 3), which is the size of the image, and give the input the name "image". The name tells TensorFlow which feature from the tfrecord file to use as input. This is particularly helpful when you have more than one feature, as the sketch below shows.
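As an illustration, suppose the tfrecords also stored a hypothetical age feature; each Input just needs a name matching the corresponding feature key (a sketch, not part of this dataset):

input_image = tf.keras.Input(shape=(224, 224, 3), name='image')
input_age = tf.keras.Input(shape=(1,), name='age')  # hypothetical extra feature

x = tf.keras.layers.GlobalAveragePooling2D()(input_image)
x = tf.keras.layers.Concatenate()([x, input_age])
out = tf.keras.layers.Dense(1, activation='sigmoid')(x)

multi_input_model = tf.keras.Model(inputs=[input_image, input_age], outputs=out)
# The parsing function would then return {'image': ..., 'age': ...} as X,
# and Keras matches each dictionary key to the Input with the same name.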
Next, we define a parsing function that converts each example into a form TensorFlow understands when fitting the model.
def parse_example(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.int64)
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_png(example["image"])
    X = dict()
    # Convert the grayscale image to 3 channels to match the model input
    X['image'] = tf.image.grayscale_to_rgb(example['image'])
    y = example['target']
    return X, y
Here, the function returns a tuple (X, y), where X is the feature and y the target. The first part of the function is the same as parse_tfrecord_fn above. The main difference is the dictionary assigned to X. The next line takes the image tensor from example['image'] and converts it from shape (224, 224, 1) to (224, 224, 3). (You can read the documentation of tf.image.grayscale_to_rgb for more details.) The converted tensor is assigned to X['image']. X is the feature dictionary we want to return along with the target, and image is used as the key here to match the name of the Input defined in the model above.
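A quick sanity check (reusing the truncated path from earlier) confirms the function returns the (X, y) structure that model.fit expects:

check_dataset = tf.data.TFRecordDataset(".../tfrecords/tfrecord_0-1000.tfrec")
check_dataset = check_dataset.map(parse_example)
for X, y in check_dataset.take(1):
    print(X["image"].shape)  # (224, 224, 3)
    print(y.numpy())         # the integer target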
Now, all we have to do is read the files. Since we created more than one tfrecord file in Part 1, we have to read all of them if we want to train the model on all of the available data. Thankfully, tf.data.TFRecordDataset accepts both a single file path and a list of file paths as input.
import glob

tf_files = glob.glob('/kaggle/input/notebook9fa8840645/tfrecords/*.tfrec')
dataset = tf.data.TFRecordDataset(tf_files)
dataset = dataset.map(parse_example)
This time around, when calling the map method, we use the parse_example parsing function as input.
Now, to train the model using the dataset, we call the batch method on the dataset and pass the result as the argument to model.fit.
dataset = dataset.shuffle(buffer_size=1024).batch(32)
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model (val_dataset can be created with the split shown later)
model.fit(dataset, epochs=10, validation_data=val_dataset)
Here, I called the shuffle method first, which fills a buffer of 1024 examples and draws randomly from it, and then the batch method, which groups the shuffled examples into batches of 32.
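Though not used above, a common optional addition is to chain prefetch, so the input pipeline prepares the next batch while the model trains on the current one:

dataset = dataset.shuffle(buffer_size=1024).batch(32).prefetch(tf.data.AUTOTUNE)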
How about a Train/Val/Test Split?
tf.data.TFRecordDataset does not provide a built-in way to perform a train-test split, but it's pretty straightforward if you know the number of examples in the dataset. TensorFlow also provides a function to get the size, but this might not work, since the cardinality of the dataset can be unknown. You can check whether the size of the dataset is known using the following block of code.
cardinality = tf.data.experimental.cardinality(dataset)
print((cardinality == tf.data.experimental.UNKNOWN_CARDINALITY).numpy())
If this returns True, then the size of the dataset is not known. If it returns False, then it is known and you can get the size using tf.data.experimental.cardinality or dataset.cardinality.
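If the cardinality is unknown, one simple (if slow) fallback is to iterate over the parsed, unbatched dataset once and count the examples:

# Counts elements by reducing over the whole dataset; requires a full pass
num_examples = dataset.reduce(0, lambda count, _: count + 1).numpy()
print(num_examples)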
Using the function below, you can create a train-val-test split.
def train_val_test_split(dataset,
                         dataset_size,
                         train_split=0.8,
                         val_split=0.1,
                         test_split=0.1,
                         shuffle=True,
                         shuffle_size=10000):
    assert (train_split + test_split + val_split) == 1
    if shuffle:
        # Specify seed and disable reshuffling so the split stays the same
        # between iterations; otherwise take/skip could mix the subsets
        dataset = dataset.shuffle(shuffle_size, seed=12,
                                  reshuffle_each_iteration=False)
    train_size = int(train_split * dataset_size)
    val_size = int(val_split * dataset_size)
    train_dataset = dataset.take(train_size)
    val_dataset = dataset.skip(train_size).take(val_size)
    test_dataset = dataset.skip(train_size).skip(val_size)
    return train_dataset, val_dataset, test_dataset
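A usage sketch, assuming num_examples holds the count from above and dataset is the parsed, unbatched dataset:

train_dataset, val_dataset, test_dataset = train_val_test_split(dataset, num_examples)

# Batch each split before training
train_dataset = train_dataset.batch(32)
val_dataset = val_dataset.batch(32)

model.fit(train_dataset, epochs=10, validation_data=val_dataset)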
by sodipe🌚 on February 23, 2023.